
# Introduction ======

# Code for plotting Figure 2, the PSD data


#Load up ----

pacman::p_load(tidyverse,patchwork,ghibli,eegUtils,
               emmeans,mgcv,lme4,broom.mixed,
               tidybayes,bayestestR)


# Pull in the shared project defaults.
# NOTE(review): `cols` and `circ_scale`, used further down, are not defined in
# this file - presumably they are created by this script; confirm.
source("./eLife Submission Scripts/Analysis-Common-Utilities.R")


# Fill colours for the annotation patches. The names ("neg" / "pos") are
# matched against the `type` column through scale_fill_manual() below.
p_cols <- c(
  neg = ghibli_palette("KikiLight")[2],
  pos = ghibli_palette("KikiMedium")[4]
)


# Load the processed data from Analysis-Figure-2-Stats.R.
#   d_psd : per group/stage/dataset PSD summaries (uses `mu` and the nested `ci`)
#   d     : cluster table (nested `clusters` column) used for the shaded patches
# The directory is spelled "eLife Submission Data" to match every other path in
# this script; the lower-case spelling only resolves on case-insensitive file
# systems.
d_psd <- readr::read_rds("./eLife Submission Data/sleep_study_psd_bootci.rds")
d     <- readr::read_rds("./eLife Submission Data/sleep_study_psd_clusters.rds")


## Fig 2A-C =====

#The PSDs for the welch and z scored signals, and the fractal bit of the IRASA signal
p_1 <-
  d_psd |>
  select(group, frequency, stage, dataset, mu, ci) |>
  unnest(ci) |>
  filter(dataset %in% c("welch", "welch_z", "frac")) |>
  ggplot(aes(x = frequency, y = mu, colour = group, fill = group)) +
  # Group mean line with its ci.low / ci.high band
  geom_line(size = 0.5) +
  geom_ribbon(
    mapping = aes(ymin = ci.low, ymax = ci.high),
    alpha = 0.4,
    colour = NA
  ) +
  facet_grid(stage ~ dataset) +
  # One full-height shaded patch per cluster, spanning the lowest to the
  # highest frequency of that cluster and filled by cluster `type`
  geom_rect(
    data = d |>
      select(-data) |>
      filter(dataset %in% c("welch", "welch_z", "frac")) |>
      unnest(clusters) |>
      mutate(
        c.min = map_dbl(frequencies, min),
        c.max = map_dbl(frequencies, max),
        # `frequency` and `mu` are placeholders - presumably present so the
        # x / y mappings inherited from ggplot() resolve in this layer's data
        frequency = map_dbl(frequencies, median),
        mu = 0
      ),
    mapping = aes(xmin = c.min, xmax = c.max, group = nc, fill = type),
    ymin = -Inf,
    ymax = Inf,
    alpha = 0.3,
    colour = NA
  ) +
  # Fill needs both the group colours and the patch colours
  scale_fill_manual(values = c(cols, p_cols)) +
  scale_colour_manual(values = cols) +
  theme_bw() +
  theme(
    panel.grid = element_blank(),
    strip.background = element_blank(),
    strip.text = element_text(colour = "grey20", size = 8),
    axis.title = element_text(colour = "black", size = 8),
    axis.text.x = element_text(colour = "grey20", size = 8),
    axis.text.y = element_text(colour = "grey20", size = 8),
    legend.title = element_blank(),
    legend.justification = "top",
    legend.position = "none"
  ) +
  labs(
    title = "PSD Group Comparison",
    subtitle = "Cz",
    x = "Frequency",
    y = "Power (dB)"
  )

## Fig 2D =====

#The IRASA oscillatory data
p_2 <-
  d_psd |>
  select(group, frequency, stage, dataset, mu, ci) |>
  unnest(ci) |>
  filter(dataset == "logo") |>
  ggplot(aes(x = frequency, y = mu, colour = group, fill = group)) +
  # Group mean line with its ci.low / ci.high band
  geom_line(size = 0.5) +
  geom_ribbon(
    mapping = aes(ymin = ci.low, ymax = ci.high),
    alpha = 0.4,
    colour = NA
  ) +
  facet_grid(stage ~ dataset) +
  # Full-height shaded patches for the clusters of the "logo" dataset only
  geom_rect(
    data = d |>
      select(-data) |>
      unnest(clusters) |>
      mutate(
        c.min = map_dbl(frequencies, min),
        c.max = map_dbl(frequencies, max),
        # `frequency` and `mu` are placeholders - presumably present so the
        # x / y mappings inherited from ggplot() resolve in this layer's data
        frequency = map_dbl(frequencies, median),
        mu = 0
      ) |>
      filter(dataset == "logo"),
    mapping = aes(xmin = c.min, xmax = c.max, group = nc, fill = type),
    ymin = -Inf,
    ymax = Inf,
    alpha = 0.3,
    colour = NA
  ) +
  scale_fill_manual(values = c(cols, p_cols)) +
  scale_colour_manual(values = cols) +
  theme_bw() +
  theme(
    panel.grid = element_blank(),
    strip.background = element_blank(),
    strip.text = element_text(colour = "grey20", size = 8),
    # Column strip is removed for this single-dataset panel
    strip.text.x = element_blank(),
    axis.title = element_text(colour = "black", size = 8),
    axis.text.x = element_text(colour = "grey20", size = 8),
    axis.text.y = element_text(colour = "grey20", size = 8),
    legend.title = element_blank(),
    legend.justification = "top",
    legend.position = "none"
  ) +
  labs(x = "Frequency", y = "Oscillatory Residual (AU)")



# Combine the two PSD figures side by side with a 3:1 width ratio
p_psd <- (p_1 + p_2) + plot_layout(widths = c(3, 1))

# p_psd

# Note we don't include the REM panel in the paper as detecting oscillations in
# REM does not make scientific sense. However, we keep the REM panel here in
# order to make the figure assemble with the correct layout ratio.

### Save these plots =====

ggsave(
  filename = "./Figures/figure_2_psds.pdf",
  plot = p_psd,
  width = 16,
  height = 12,
  units = "cm"
)



# Figure 2E-G =======

## Load the GAMM output data ==== 

# To keep the plot simple, we plot the topographic difference plot in the main figure
# and put the individual group plots into the supplement

# Load the topoplot plotting data
d_gamm <- readr::read_rds(
  "./eLife Submission Data/sleep_study_topoplot_posterior_data.rds"
)

# Display labels for the measure codes (used as the facet labeller below)
spectral_details <- c(
  meanSOP    = "SO Power",
  meanSigmaP = "Sigma Power",
  peakSigmaF = "Sigma Frequency",
  cons       = "Intercept",
  beta       = "Slope"
)



topo_diff <-
  d_gamm |>
  select(stage, measure, post_draws) |>
  filter(
    measure %in% c("meanSigmaP", "peakSigmaF", "meanSOP", "cons", "beta")
  ) |>
  # Fix the facet order of the measures
  mutate(
    measure = factor(
      measure,
      levels = c("meanSOP", "meanSigmaP", "peakSigmaF", "cons", "beta")
    )
  ) |>
  unnest(post_draws) |>
  # Keep only the grid points that fall inside the head circle
  mutate(incircle = sqrt(x^2 + y^2) < circ_scale) |>
  filter(incircle) |>
  ggplot(aes(x = x, y = y, fill = diff)) +
  # Transparency follows `pv`; its legend is switched off further down
  geom_raster(aes(alpha = pv)) +
  geom_mask(r = circ_scale, size = 0.5) +
  geom_head(r = circ_scale, size = 0.5) +
  # Values outside [-1, 1] are squished onto the limits, not dropped
  scale_fill_distiller(
    palette = "RdBu",
    limits = c(-1, 1),
    oob = scales::squish
  ) +
  coord_equal() +
  theme_void() +
  scale_alpha(range = c(0.1, 1)) +
  facet_wrap(
    ~ stage + measure,
    ncol = 5,
    labeller = labeller(measure = spectral_details),
    strip.position = "top"
  ) +
  labs(fill = "Group Difference") +
  theme(
    strip.text = element_text(colour = "black", size = 8, angle = 0),
    legend.position = "bottom",
    legend.direction = "horizontal",
    legend.title = element_text(colour = "black", size = 8),
    legend.text = element_text(colour = "black", size = 6, angle = 40)
  ) +
  guides(alpha = "none")


# topo_diff


# Save this
ggsave(
  filename = "./Figures/figure_2_topo_gamm.pdf",
  plot = topo_diff,
  width = 24,
  height = 12,
  units = "cm"
)



# Figure 2 - Supplement 1: Individual Data ======

## Fig 2 - S1A: Individual PSDs =====

# Load PSD dataset.
# Note: this rebinds `d`; from here on it no longer refers to the cluster table
# loaded at the top of the script.
d <- readr::read_rds("./eLife Submission Data/sleep_study_psd_data.rds")

# Keep only the "logo" dataset, stage N2, at electrode Cz
d_irasa <-
  d |>
  filter(
    dataset == "logo",
    stage == "N2",
    location == "Cz"
  )


# Plot each individual's PSD, one small panel per subject (7 per row)
p_individual_1 <-
  d_irasa |>
  select(subject, group, frequency, power) |>
  ggplot(aes(x = frequency, y = power, colour = group, fill = group)) +
  geom_line(size = 0.5) +
  facet_wrap(~ subject, ncol = 7) +
  scale_fill_manual(values = cols) +
  scale_colour_manual(values = cols) +
  theme_bw() +
  theme(
    panel.grid = element_blank(),
    strip.background = element_blank(),
    strip.text = element_text(colour = "grey20", size = 8),
    # Subject labels are hidden
    strip.text.x = element_blank(),
    axis.title = element_text(colour = "black", size = 8),
    axis.text.x = element_text(colour = "grey20", size = 8),
    axis.text.y = element_text(colour = "grey20", size = 8),
    legend.title = element_blank(),
    legend.justification = "top",
    legend.position = "none"
  ) +
  labs(x = "Frequency", y = "Oscillatory Residual (AU)")



## Fig 2 - S1B: Make a plot of individual data as dots and boxplot  =====

# Load the EEG summary data and keep only the measures shown in this figure
d_sum <-
  readr::read_rds("./eLife Submission Data/sleep_study_eeg_summary_data.rds") |>
  filter(
    measure %in% c("meanSOP", "meanSigmaP", "peakSigmaF", "cons", "beta")
  )

# Same vibe as figure 1: one boxplot-plus-dots panel per measure x stage,
# using the Cz electrode only
p_individual_2 <-
  d_sum |>
  select(stage, measure, data) |>
  unnest(data) |>
  filter(electrode == "Cz") |>
  select(-c(x, y, electrode)) |>
  # Re-nest so each row carries the data for one measure x stage panel
  group_by(measure, stage) |>
  nest() |>
  mutate(
    measure = factor(
      measure,
      levels = c("meanSOP", "meanSigmaP", "peakSigmaF", "cons", "beta")
    )
  ) |>
  arrange(stage, measure) |>
  # Combined measure/stage key, used as the y-axis title of each panel
  mutate(m_s = interaction(measure, stage)) |>
  mutate(
    plot = map2(data, m_s, \(panel_data, panel_label) {
      ggplot(
        data = panel_data,
        aes(x = group, y = value, fill = group, colour = group)
      ) +
        # Outliers are hidden on the box because every point is drawn anyway
        geom_boxplot(
          alpha = 0.2,
          lwd = 0.25,
          outlier.color = NA,
          outlier.fill = NA
        ) +
        # Horizontal-only jitter so the plotted values are not distorted
        geom_point(
          size = 0.5,
          position = ggforce::position_jitternormal(sd_x = 0.05, sd_y = 0),
          alpha = 0.6
        ) +
        scale_fill_manual(values = cols) +
        scale_colour_manual(values = cols) +
        theme_bw() +
        theme(
          panel.grid = element_blank(),
          strip.background = element_blank(),
          strip.text = element_text(colour = "grey20", size = 8),
          strip.text.x = element_blank(),
          axis.title = element_text(colour = "black", size = 8),
          axis.text.x = element_text(colour = "grey20", size = 8),
          axis.text.y = element_text(colour = "grey20", size = 8),
          legend.title = element_blank(),
          legend.justification = "top",
          legend.position = "none"
        ) +
        labs(x = "Group", y = panel_label)
    })
  ) |>
  pull(plot) |>
  wrap_plots(ncol = 5)


## Fig 2 - S1C: Scatter with age for a figure supplement =======

# One scatter panel per measure x stage: value against age_eeg at Cz, with a
# linear fit per group
p_individual_3 <-
  d_sum |>
  select(stage, measure, data) |>
  unnest(data) |>
  filter(electrode == "Cz") |>
  select(-c(x, y, electrode)) |>
  # Re-nest so each row carries the data for one measure x stage panel
  group_by(measure, stage) |>
  nest() |>
  mutate(
    measure = factor(
      measure,
      levels = c("meanSOP", "meanSigmaP", "peakSigmaF", "cons", "beta")
    )
  ) |>
  arrange(stage, measure) |>
  # Combined measure/stage key, used as the y-axis title of each panel
  mutate(m_s = interaction(measure, stage)) |>
  mutate(
    plot = map2(data, m_s, \(panel_data, panel_label) {
      ggplot(
        data = panel_data,
        aes(x = age_eeg, y = value, fill = group, colour = group)
      ) +
        geom_point(size = 0.5, alpha = 0.6) +
        geom_smooth(
          size = 0.5,
          formula = y ~ x,
          method = "lm"
        ) +
        scale_fill_manual(values = cols) +
        scale_colour_manual(values = cols) +
        theme_bw() +
        theme(
          panel.grid = element_blank(),
          strip.background = element_blank(),
          strip.text = element_text(colour = "grey20", size = 8),
          strip.text.x = element_blank(),
          axis.title = element_text(colour = "black", size = 8),
          axis.text.x = element_text(colour = "grey20", size = 8),
          axis.text.y = element_text(colour = "grey20", size = 8),
          legend.title = element_blank(),
          legend.justification = "top",
          legend.position = "none"
        ) +
        labs(x = "Age", y = panel_label)
    })
  ) |>
  pull(plot) |>
  wrap_plots(ncol = 5)

# Assemble: stack the three supplement parts vertically and tag them with
# capital letters
p_individual <-
  (p_individual_1 / p_individual_2 / p_individual_3) +
  plot_annotation(tag_levels = "A")

## Save =====

ggsave(
  filename = "./Figures/figure_2_supplement_1.pdf",
  plot = p_individual,
  width = 16,
  height = 30,
  units = "cm"
)

#And then we do the final alignment in Inkscape


# Figure 2 - Supplement 2: Group Topos -----


# Topoplots of the average data for each group

d_topo <- readr::read_rds(
  "./eLife Submission Data/sleep_study_eeg_summary_data.rds"
)

# Display labels for the measure codes. This repeats the mapping defined in
# the Figure 2E-G section, so this section can be run on its own.
# NOTE(review): the code below in this file does not reference it - confirm
# whether the topoplot subtitles were meant to use these labels.
spectral_details <- c(
  meanSOP    = "SO Power",
  meanSigmaP = "Sigma Power",
  peakSigmaF = "Sigma Frequency",
  cons       = "Intercept",
  beta       = "Slope"
)

# Do a topoplot of the raw means: one ggplot per measure, stored in the `plot`
# list-column, with a stage x group facet column inside each plot
topo_psd <-
  d_topo |>
  # Kick out REM where it doesn't make sense: N2 / N3 rows are kept for every
  # measure, REM rows only for the slope / intercept measures (beta, cons)
  filter(stage %in% c("N2", "N3") | measure %in% c("beta", "cons")) |>
  # NOTE(review): any measure outside these five levels becomes NA here and
  # would form an extra trailing group - confirm d_topo holds only these
  mutate(
    measure = factor(
      measure,
      levels = c("meanSOP", "meanSigmaP", "peakSigmaF", "cons", "beta")
    )
  ) |>
  arrange(measure) |>
  ungroup() |>
  unnest(data) |>
  group_by(measure) |>
  nest() |>
  mutate(
    plot = map2(data, measure, \(topo_data, topo_label) {
      ggplot(
        data = topo_data,
        aes(
          x = x,
          y = y,
          z = value,
          fill = value,
          label = electrode
        )
      ) +
        geom_topo(
          grid_res = 200,
          colour = "white",
          size = 0.1,
          interp_limit = "head",
          chan_markers = "point",
          chan_size = 0.25,
          head_size = 0.5,
          method = "gam",
          breaks = 10
        ) +
        scale_fill_viridis_c(option = "H") +
        facet_wrap(~ stage + group, ncol = 1) +
        theme_void() +
        coord_equal() +
        labs(subtitle = topo_label, fill = topo_label) +
        theme(
          plot.subtitle = element_text(colour = "black", size = 8),
          strip.text = element_text(colour = "black", size = 8, angle = 0),
          legend.position = "bottom",
          legend.title = element_blank(),
          legend.text = element_text(colour = "black", size = 6, angle = 40)
        )
    })
  )



# Lay the five per-measure topoplots out in a single row. The first three
# measures are each paired with an empty spacer; the last two are used as is.
# The parentheses only spell out R's precedence (`+` binds tighter than `|`).
# NOTE(review): `+` leaves placement to patchwork's default grid; if the spacer
# was meant to sit *below* each plot so that `heights = c(2, 1)` takes effect,
# `/` would be required - check against the rendered figure.
tp_s2 <-
  ((topo_psd$plot[[1]] + plot_spacer()) + plot_layout(heights = c(2, 1))) |
  ((topo_psd$plot[[2]] + plot_spacer()) + plot_layout(heights = c(2, 1))) |
  ((topo_psd$plot[[3]] + plot_spacer()) + plot_layout(heights = c(2, 1))) |
  topo_psd$plot[[4]] |
  topo_psd$plot[[5]]

# tp_s2

## Save ====

ggsave(
  filename = "./Figures/figure_2_supplement_2.pdf",
  plot = tp_s2,
  width = 24,
  height = 24,
  units = "cm"
)


#Final assembly and layout in Inkscape